# -*- coding: utf-8 -*-
"""
Created on Wed Dec 22 16:06:48 2021

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import numpy as np
from torch.autograd import Variable
import time
import pandas as pd

import networkx as nx
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter
from datetime import datetime
from torch.nn.utils import clip_grad_norm_
import copy
import random
import pickle
import os

from models import PhyCRNet,MaskedMSELoss,MaskedL1Loss,weightMSELoss
from data_utils import PrepareConcDataset
from utils import plot_conc,read_logs


#%%
args = {}
# Use the GPU when one is present; otherwise fall back to the CPU so that
# `device` is always defined for the model construction/loading below.
args['cuda'] = torch.cuda.is_available()
device = torch.device('cuda' if args['cuda'] else 'cpu')


# Experiment folders (relative to ./exp/) to evaluate; list several entries,
# e.g. [f'exp{i}/' for i in range(2, 4)], to sweep multiple experiments.
experis = ['exp8/']


def _build_model(run_args, dev):
    """Instantiate a PhyCRNet from a run's pickled args dict and move it to *dev*."""
    net = PhyCRNet(run_args['input_dim'],
                   run_args['hidden_dim'],
                   run_args['output_dim'],
                   run_args['input_kernel_size'],
                   run_args['input_stride'],
                   run_args['padding_mode'],
                   run_args['num_layers'],
                   run_args['dropout'],
                   run_args['warmup_updates'],
                   run_args['tot_updates'],
                   run_args['peak_lr'],
                   run_args['end_lr'],
                   run_args['power'],
                   run_args['weight_decay'])
    return net.to(device=dev)


def _count_params(net, key):
    """Return the total number of elements in parameters whose name contains *key*."""
    return sum(p.numel() for name, p in net.named_parameters() if key in name)


for exp_dir in experis:
    # Track the run with the lowest MAE reported by plot_conc.
    best_mae = float('inf')
    best_model_path = None
    best_args = None

    loot_path = './exp/' + exp_dir
    log_path = 'log_conc/'
    # read_logs is expected to return parallel lists keyed by 'log_dirs',
    # 'model_dirs', 'args_dirs', 'events_dirs' and 'orders' — one entry per run.
    dirs = read_logs(loot_path + log_path)

    for index in range(len(dirs['orders'])):
        # NOTE: pickle executes code on load; only use trusted, locally produced files.
        with open(dirs['args_dirs'][index], 'rb') as f:
            args = pickle.load(f)

        # The returned tensors are not used below. The call is kept in case
        # PrepareConcDataset has side effects.
        # NOTE(review): drop it if plot_conc prepares its own data.
        PrepareConcDataset(args['time_window'], args['res'],
                           args['pred_day'], args['buffer'],
                           args['gap'], args['test_day'],
                           args['val_ratio'])

        model_path = dirs['model_dirs'][index]
        model = _build_model(args, device)
        # map_location lets GPU-saved checkpoints load on CPU-only machines.
        model.load_state_dict(torch.load(model_path, map_location=device))
        all_mae = plot_conc(args, model, dirs['log_dirs'][index], device, 'sig')

        if all_mae < best_mae:
            best_mae = all_mae
            best_model_path = model_path
            best_args = args

        total_weight = _count_params(model, 'weight')
        total_bias = _count_params(model, 'bias')
        print(f'the total number of weight is {total_weight}')
        print(f'the total number of bias is {total_bias}')

    if best_model_path is None:
        print(f'no evaluable model found in {loot_path + log_path}; skipping')
        continue

    # Rebuild from the best run's own args: the last run evaluated above may
    # have a different architecture, and plot_conc must see the matching args.
    best_model = _build_model(best_args, device)
    best_model.load_state_dict(torch.load(best_model_path, map_location=device))
    all_mae = plot_conc(best_args, best_model, loot_path, device, 'best')

